{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3d308cac-2783-4fa1-9502-d9e4f85c8c34",
   "metadata": {},
   "source": [
    "Reference: [Extremely randomized trees (scikit-learn 1.1 user guide)](https://scikit-learn.org/1.1/modules/ensemble.html#extremely-randomized-trees)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d594e692-7386-431c-a7fb-a7a9588d74bb",
   "metadata": {},
   "source": [
    "## Training and Logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c0d7d7a8-214b-408e-b07b-903c0318639a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((442, 10), (442,))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.datasets import load_diabetes\n",
    "\n",
    "# Load the diabetes dataset as a (features, target) pair and show both shapes.\n",
    "X, y = load_diabetes(return_X_y=True)\n",
    "X.shape, y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9a4827b5-f69d-4d8b-b00d-41c4714e1dc4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(array([ 0.03807591,  0.05068012,  0.06169621,  0.02187239, -0.0442235 ,\n",
       "         -0.03482076, -0.04340085, -0.00259226,  0.01990749, -0.01764613]),\n",
       "  151.0),\n",
       " (array([-0.00188202, -0.04464164, -0.05147406, -0.02632753, -0.00844872,\n",
       "         -0.01916334,  0.07441156, -0.03949338, -0.06833155, -0.09220405]),\n",
       "  75.0)]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Peek at the first two (feature vector, target) pairs.\n",
    "[(X[i], y[i]) for i in range(2)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6c40aafa-4871-41d7-9676-cf646a3e1f76",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/da/.cache/pants/named_caches/pex_root/venvs/s/378f4125/venv/lib/python3.9/site-packages/_distutils_hack/__init__.py:33: UserWarning: Setuptools is replacing distutils.\n",
      "  warnings.warn(\"Setuptools is replacing distutils.\")\n",
      "2023/03/29 13:12:33 WARNING mlflow.utils.environment: Failed to resolve installed pip version. ``pip`` will be added to conda.yaml environment spec without a version specifier.\n",
      "Registered model 'da_extra_trees' already exists. Creating a new version of this model...\n",
      "2023/03/29 13:12:33 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: da_extra_trees, version 2\n",
      "Created version '2' of model 'da_extra_trees'.\n",
      "/home/da/.cache/pants/named_caches/pex_root/venvs/s/378f4125/venv/lib/python3.9/site-packages/liga/mlflow/logger.py:137: UserWarning: value of rikai.output.schema is None or empty and will not be populated to MLflow\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import getpass\n",
    "\n",
    "import mlflow\n",
    "from liga.sklearn.mlflow import log_model\n",
    "from sklearn.ensemble import ExtraTreesRegressor\n",
    "\n",
    "# Use a local SQLite database file (mlruns.db) as the MLflow tracking store.\n",
    "mlflow_tracking_uri = \"sqlite:///mlruns.db\"\n",
    "mlflow.set_tracking_uri(mlflow_tracking_uri)\n",
    "\n",
    "# The registered model name is prefixed with the current login name.\n",
    "registered_model_name = f\"{getpass.getuser()}_extra_trees\"\n",
    "\n",
    "with mlflow.start_run() as run:\n",
    "    # Fit on the full dataset (fixed seed), then log and register the model.\n",
    "    model = ExtraTreesRegressor(n_estimators=100, random_state=0).fit(X, y)\n",
    "    log_model(model, registered_model_name=registered_model_name)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93af5fdb-0b86-4906-8646-3d83ee0658fb",
   "metadata": {},
   "source": [
    "## Apply the model to a large-scale dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "390feb1c-3d6b-4a89-b8cf-3ee7625ff2c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-03-29 13:12:33,283 INFO Rikai (__init__.py:121): setting spark.sql.extensions to net.xmacs.liga.spark.RikaiSparkSessionExtensions\n",
      "2023-03-29 13:12:33,285 INFO Rikai (__init__.py:121): setting spark.driver.extraJavaOptions to -Dio.netty.tryReflectionSetAccessible=true\n",
      "2023-03-29 13:12:33,286 INFO Rikai (__init__.py:121): setting spark.executor.extraJavaOptions to -Dio.netty.tryReflectionSetAccessible=true\n",
      "2023-03-29 13:12:33,287 INFO Rikai (__init__.py:121): setting spark.jars to https://github.com/komprenilo/liga/releases/download/v0.3.0/liga-spark321-assembly_2.12-0.3.0.jar\n",
      "23/03/29 13:12:34 WARN Utils: Your hostname, tubi resolves to a loopback address: 127.0.1.1; using 192.168.31.32 instead (on interface wlp0s20f3)\n",
      "23/03/29 13:12:34 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
      "Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties\n",
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "23/03/29 13:12:38 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
      "23/03/29 13:12:39 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n",
      "23/03/29 13:12:39 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.\n",
      "23/03/29 13:12:39 WARN Utils: Service 'SparkUI' could not bind on port 4042. Attempting port 4043.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------+------+------------------------+-------+\n",
      "|name            |plugin|uri                     |options|\n",
      "+----------------+------+------------------------+-------+\n",
      "|mlflow_sklearn_m|      |mlflow:///da_extra_trees|       |\n",
      "+----------------+------+------------------------+-------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from example import spark\n",
    "from liga.mlflow import CONF_MLFLOW_TRACKING_URI\n",
    "\n",
    "spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"false\")\n",
    "# Point the Spark session at the same MLflow tracking URI used when logging the model.\n",
    "spark.conf.set(CONF_MLFLOW_TRACKING_URI, mlflow_tracking_uri)\n",
    "\n",
    "# Expose the registered MLflow model to SQL under the name mlflow_sklearn_m.\n",
    "spark.sql(\n",
    "    f\"\"\"\n",
    "CREATE OR REPLACE MODEL mlflow_sklearn_m LOCATION 'mlflow:///{registered_model_name}';\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "# List the first catalog entry without truncating the columns.\n",
    "spark.sql(\"show models\").show(n=1, truncate=False, vertical=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8ac2de60-e53c-4b26-b999-4e95b4cdeaf7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- mlflow_sklearn_m: float (nullable = true)\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mlflow_sklearn_m</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>151.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   mlflow_sklearn_m\n",
       "0             151.0"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from liga.numpy.sql import literal\n",
    "\n",
    "# Embed the first sample into the SQL text and score it with the registered model.\n",
    "result = spark.sql(\n",
    "    f\"\"\"\n",
    "select\n",
    "  ML_PREDICT(mlflow_sklearn_m, {literal(X[0])})\n",
    "\"\"\"\n",
    ")\n",
    "\n",
    "# Print the schema of the prediction column, then display the value.\n",
    "result.printSchema()\n",
    "result.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cb9cf20d-59ee-460b-baf4-3c6d391a4fe9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>mlflow_sklearn_m</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>75.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   mlflow_sklearn_m\n",
       "0              75.0"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the second sample the same way and display the result.\n",
    "spark.sql(\n",
    "    f\"\"\"\n",
    "select  ML_PREDICT(mlflow_sklearn_m, {literal(X[1])})\n",
    "\"\"\"\n",
    ").toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "612f3dc3-934c-4805-9d56-5e439a378384",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
